import torch
from torch import nn


if __name__ == "__main__":
    # Smoke-test a single 2-D convolution: 1 input channel -> 6 feature maps,
    # 5x5 kernel, unit stride, no padding. The layer is built before the input
    # is sampled so the RNG draws happen in the same order as before.
    layer = nn.Conv2d(1, 6, 5, stride=1, padding=0)

    # One random single-channel 128x128 image (batch, channels, height, width).
    image = torch.randn(1, 1, 128, 128)
    print(image.shape)

    # With no padding each spatial side shrinks to (128 - 5) // 1 + 1 = 124.
    feature_maps = layer(image)
    print(feature_maps.shape)
